import os
from pathlib import Path
import math
from argparse import ArgumentParser

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

try:
    import wandb
except ImportError:
    print('wandb not available')
try:
    import ray
except ImportError:
    print('ray not available')

import data
import losses
import models
import ap


def cmdline_args(argv=None):
    """Parse command-line options and fill in dataset-dependent values.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and programmatic use). ``None``
            keeps the usual behaviour of reading the process command line.

    Returns:
        argparse.Namespace with every option. For ``--dataset clevr`` and
        ``--dataset mnist`` the ``set_size``/``set_dim`` options are
        overwritten with fixed values for those datasets. ``project`` is
        derived from the dataset when it is not given explicitly.

    Exits through ``parser.error`` (status 2, usage message on stderr) when
    ``--dataset random`` is combined with a non-positive ``--set_size`` or
    ``--set_dim``.
    """
    parser = ArgumentParser()
    # experiment config
    parser.add_argument('--project', default=None)
    parser.add_argument('--name', default='default')
    # dataset config
    parser.add_argument('--dataset', choices=['random', 'clevr', 'mnist'], default='random')
    parser.add_argument('--loss', choices=['hungarian', 'chamfer'], default='hungarian')
    parser.add_argument('--set_size', type=int, default=10)
    parser.add_argument('--set_dim', type=int, default=2)
    parser.add_argument('--dataset_size', type=int, default=64000)
    parser.add_argument('--clevr_path', default='clevr')
    parser.add_argument('--clevr_image_input', action='store_true')
    parser.add_argument('--clevr_image_size', type=int, default=128)
    # training config
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--n_gpus', type=int, default=1)
    parser.add_argument('--lr_drop_epoch', type=int, default=None)
    parser.add_argument('--checkpoint_path', default='checkpoints')
    parser.add_argument('--num_data_workers', type=int, default=0)
    parser.add_argument('--num_ray_workers', type=int, default=0)
    # model config
    parser.add_argument('--model', default='idspn', choices=['idspn', 'dspn', 'slot'])
    parser.add_argument('--latent_dim', type=int, default=512)
    parser.add_argument('--hidden_dim', type=int, default=512)
    parser.add_argument('--input_encoder', default='rnfs', choices=['fs','rnfs'])
    # idspn config
    parser.add_argument('--decoder_encoder', default='fs', choices=['fs','rnfs'])
    parser.add_argument('--decoder_lr', type=float, default=1.0)
    parser.add_argument('--decoder_iters', type=int, default=20)
    parser.add_argument('--decoder_momentum', type=float, default=0.9)
    parser.add_argument('--decoder_val_iters', type=int, default=None)
    parser.add_argument('--decoder_grad_clip', type=float)
    parser.add_argument('--decoder_it_schedule', action='store_true')
    parser.add_argument('--decoder_starting_set', action='store_true')
    # wandb config
    parser.add_argument('--no_wandb', dest='use_wandb', action='store_false')
    # eval config
    parser.add_argument('--progress_num_examples', type=int, default=0)
    parser.add_argument('--progress_path', default='progress')
    parser.add_argument('--eval_checkpoint', default=None)
    parser.add_argument('--test_after_training', action='store_true')
    parser.add_argument('--save_predictions', type=str)

    args = parser.parse_args(argv)

    if args.dataset == 'random':
        # parser.error instead of assert: asserts vanish under ``python -O``
        # and would give the CLI user a bare traceback instead of usage help.
        if args.set_size <= 0:
            parser.error('--set_size must be positive for --dataset random')
        if args.set_dim <= 0:
            parser.error('--set_dim must be positive for --dataset random')
    elif args.dataset == 'clevr':
        # Fixed set size / element dimension for CLEVR; CLI values are ignored.
        args.set_size = 10
        args.set_dim = 19
    elif args.dataset == 'mnist':
        # Fixed set size / element dimension for MNIST; CLI values are ignored.
        args.set_size = 342
        args.set_dim = 3

    if args.project is None:
        if args.dataset == 'random':
            args.project = f'random-dim{args.set_dim}-size{args.set_size}'
        elif args.dataset == 'clevr':
            args.project = 'clevr-' + ('images' if args.clevr_image_input else 'autoencode')
        elif args.dataset == 'mnist':
            args.project = 'mnist'
    return args


class SetPredictionModel(pl.LightningModule):
    """LightningModule that trains a set predictor (iDSPN, DSPN or Slot Attention).

    The network, the datasets and most training behaviour are selected by the
    ``args`` namespace produced by ``cmdline_args``. ``self.net`` is called on
    the input half of a batch and must return ``(predicted_set, set_grad)``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.save_hyperparameters(args)
        if 'dspn' in args.model:
            # Covers both 'dspn' and 'idspn'; ``implicit`` tells them apart.
            self.net = models.DSPNModel(
                d_in=self.hparams.set_dim,
                d_hid=self.hparams.hidden_dim,
                d_latent=self.hparams.latent_dim,
                set_size=self.hparams.set_size,

                input_encoder=self.hparams.input_encoder,
                decoder_encoder=self.hparams.decoder_encoder,
                lr=args.decoder_lr,
                iters=args.decoder_iters,
                momentum=args.decoder_momentum,
                grad_clip=args.decoder_grad_clip,
                use_starting_set=args.decoder_starting_set,

                image_input=self.hparams.clevr_image_input,
                image_size=self.hparams.clevr_image_size,
                implicit=self.hparams.model == 'idspn',
            )
        else:
            self.net = models.SlotAttentionModel(
                d_in=self.hparams.set_dim,
                d_hid=self.hparams.hidden_dim,
                set_size=args.set_size,
            )

        if args.dataset == 'random':
            self.trainset = data.Objects(size=args.dataset_size, cardinality=args.set_size, dim=args.set_dim)
            self.valset = data.Objects(size=args.dataset_size // 10, cardinality=args.set_size, dim=args.set_dim)
        elif args.dataset.startswith('clevr'):
            self.trainset = data.CLEVR(args.clevr_path, 'train', image_input=args.clevr_image_input, image_size=self.args.clevr_image_size)
            self.valset = data.CLEVR(args.clevr_path, 'val', image_input=args.clevr_image_input, image_size=self.args.clevr_image_size)
        elif args.dataset == 'mnist':
            self.trainset = data.MNISTSetMasked(train=True)
            self.valset = data.MNISTSetMasked(train=False)
        else:
            # Fail here instead of with an AttributeError in the dataloaders.
            raise ValueError(f'unknown dataset: {args.dataset!r}')
        # Prefix for AP metric names; 'test/' while a test epoch is running.
        self.ap_prefix = ''
        # Decoder iterations active before a validation epoch, restored after it.
        self._saved_train_iters = None

    def forward(self, x):
        """Run the network on a ``(inputs, targets)`` batch.

        Returns ``(output, gt_output, set_grad)``. ``output`` and ``set_grad``
        come from ``self.net``. ``gt_output`` is the batch target, passed
        through unchanged.
        """
        inputs, gt_output = x
        output, set_grad = self.net(inputs)
        return output, gt_output, set_grad

    def _compute_loss(self, output, gt_output):
        """Return the set loss selected by ``--loss``, averaged over dim 0."""
        if self.args.loss == 'hungarian':
            loss = losses.hungarian_loss(output, gt_output, num_workers=self.args.num_ray_workers)
        else:
            loss = losses.chamfer_loss(output, gt_output)
        return loss.mean(0)

    def training_step(self, batch, batch_nb):
        output, gt_output, set_grad = self(batch)
        loss = self._compute_loss(output, gt_output)
        # assumes set_grad is 3-D (norm over its last two dims) — TODO confirm
        grad_norm = set_grad.norm(dim=[1, 2]).mean()

        self.log_dict({'loss/train': loss, 'grad_norm/train': grad_norm})
        return loss

    def validation_step(self, batch, batch_idx):
        output, gt_output, set_grad = self(batch)
        loss = self._compute_loss(output, gt_output)
        grad_norm = set_grad.norm(dim=[1, 2]).mean()

        if batch_idx == 0 and self.args.progress_num_examples > 0:
            path = Path(self.args.progress_path, self.args.project, self.args.name, f"{self.global_step}.png")
            self.plot_pointset(output, gt_output, path, n_examples=self.args.progress_num_examples)

        self.log_dict({'loss/val': loss, 'grad_norm/val': grad_norm})

        if self.args.dataset == 'clevr':
            thresholds = [float('inf'), 1, 0.5, 0.25, 0.125, 0.0625]
            aps = ap.compute_ap(gt_output, output, thresholds)
            # ``ap_value`` rather than ``ap`` so the ``ap`` module is not shadowed.
            self.log_dict({
                f'{self.ap_prefix}ap/{threshold}': ap_value
                for threshold, ap_value in zip(thresholds, aps)
            })

        return loss

    def test_step(self, batch, batch_idx):
        """Evaluate like validation, but log AP metrics under ``test/``."""
        self.ap_prefix = 'test/'
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=self.args.lr)
        if self.args.lr_drop_epoch is not None:
            # StepLR with its default gamma, every ``lr_drop_epoch`` epochs.
            scheduler = {
                'scheduler': torch.optim.lr_scheduler.StepLR(opt, step_size=self.args.lr_drop_epoch)
            }
            return [opt], [scheduler]
        return opt

    def train_dataloader(self):
        return DataLoader(
            self.trainset,
            batch_size=self.args.batch_size,
            shuffle=True,
            num_workers=self.args.num_data_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.valset,
            batch_size=self.args.batch_size,
            shuffle=False,
            num_workers=self.args.num_data_workers,
        )

    def _dspn(self):
        """Return ``self.net.dspn`` (the iterative decoder), or None if absent."""
        # NOTE(review): presumably only DSPNModel exposes ``.dspn``; guarding
        # keeps the iteration hooks from crashing for other models — confirm.
        return getattr(self.net, 'dspn', None)

    def _eval_iters(self):
        """Decoder iterations used during validation/test."""
        return self.args.decoder_val_iters or self.args.decoder_iters

    def on_train_epoch_start(self) -> None:
        """Set the decoder iterations for this training epoch.

        With ``--decoder_it_schedule``, the first half of the epochs uses half
        the decoder iterations. All other epochs use ``--decoder_iters``.
        """
        dspn = self._dspn()
        if dspn is None:
            return
        if self.args.decoder_it_schedule and self.current_epoch < 0.5 * self.args.epochs:
            dspn.iters = int(0.5 * self.args.decoder_iters)
        else:
            dspn.iters = self.args.decoder_iters

    def on_validation_epoch_start(self) -> None:
        """Switch the decoder to ``--decoder_val_iters`` for validation.

        Lightning only invokes hooks by their exact names. This must be
        ``on_validation_epoch_start``, not ``on_val_epoch_start``.
        """
        dspn = self._dspn()
        if dspn is None:
            return
        self._saved_train_iters = dspn.iters
        dspn.iters = self._eval_iters()

    def on_validation_epoch_end(self) -> None:
        """Restore the training-time decoder iterations after validation."""
        dspn = self._dspn()
        if dspn is not None and self._saved_train_iters is not None:
            dspn.iters = self._saved_train_iters
        self._saved_train_iters = None

    def on_val_epoch_start(self) -> None:
        """Backward-compatible alias; Lightning never calls this name itself."""
        self.on_validation_epoch_start()

    def on_test_epoch_start(self) -> None:
        """Use evaluation decoder iterations and ``test/`` AP metric names."""
        self.ap_prefix = 'test/'
        dspn = self._dspn()
        if dspn is not None:
            dspn.iters = self._eval_iters()

    def on_test_epoch_end(self) -> None:
        """Reset the AP prefix so a later validation logs under plain names."""
        self.ap_prefix = ''

    def plot_pointset(self, pred, target, filename, n_examples):
        """Save a grid of scatter plots comparing predicted and target sets.

        Args:
            pred, target: batched set tensors. After ``transpose(1, 2)``,
                ``[i, 0]`` and ``[i, 1]`` are the first two element features
                of example ``i``. Assumes (batch, set_size, set_dim) input —
                TODO confirm.
            filename: path (``str`` or ``Path``) of the image to write. Parent
                directories are created.
            n_examples: number of examples to draw, clamped to the batch size.
        """
        n_examples = min(n_examples, pred.shape[0])
        if n_examples <= 0:
            return
        filename = Path(filename)
        n_rows = n_cols = math.ceil(n_examples ** 0.5)
        fig, axs = plt.subplots(n_rows, n_cols, squeeze=False, figsize=(15, 15))
        try:
            # detach(): decoder outputs may still require grad, and matplotlib
            # cannot convert such tensors to numpy.
            pred = pred.detach().cpu().transpose(1, 2)
            target = target.detach().cpu().transpose(1, 2)

            lim = (0, 1) if self.args.dataset == 'mnist' else (-3, 3)

            for i in range(n_examples):
                ax = axs[i // n_cols, i % n_cols]
                if self.args.dataset == 'mnist':
                    # Feature 1 on x, 1 - feature 0 on y (presumably image
                    # (row, col) coordinates drawn upright — confirm).
                    ax.scatter(target[i, 1], 1 - target[i, 0], marker='o', s=5**2)
                    ax.scatter(pred[i, 1], 1 - pred[i, 0], marker='x', s=5**2)
                else:
                    ax.scatter(target[i, 0], target[i, 1], marker='o', s=5**2)
                    ax.scatter(pred[i, 0], pred[i, 1], marker='x', s=5**2)
                ax.axis("equal")
                ax.set_xlim(*lim)
                ax.set_ylim(*lim)

            fig.tight_layout()
            filename.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(filename)
        finally:
            # pyplot keeps every figure alive until closed; without this a new
            # figure leaks on every validation epoch.
            plt.close(fig)


def train(args):
    """Train a ``SetPredictionModel`` configured by ``args`` and return it.

    Starts ray when ``--num_ray_workers > 0``. Logs to Weights & Biases unless
    ``--no_wandb`` is given. Writes checkpoints to
    ``<checkpoint_path>/<project>/<name>``. With ``--test_after_training``,
    runs ``test`` on the trained model using the same trainer.
    """
    model = SetPredictionModel(args)

    if args.num_ray_workers > 0:
        # presumably backs losses.hungarian_loss(num_workers=...) — confirm
        ray.init(num_cpus=args.num_ray_workers, include_dashboard=False)

    # Bound unconditionally so the Trainer call below does not depend on
    # which branch ran. None disables logging for the Trainer.
    logger = None
    if args.use_wandb:
        wandb.init(
            name=args.name,
            project=args.project,
            reinit=False,
            # settings=wandb.Settings(start_method="fork"),
        )
        logger = WandbLogger(log_model=True)
        logger.watch(model.net)
        wandb.config.update(args)

    checkpoint_path = os.path.join(args.checkpoint_path, args.project, args.name)
    trainer = pl.Trainer(
        max_epochs=args.epochs,
        # Only a tenth of the CLEVR validation split is used during training.
        limit_val_batches=0.1 if args.dataset == 'clevr' else 1.0,
        gpus=args.n_gpus,
        num_nodes=1,
        logger=logger,
        callbacks=[
            ModelCheckpoint(dirpath=checkpoint_path),
        ],
    )

    trainer.fit(model)
    if args.test_after_training:
        test(args, model, trainer)
    return model


def test(args, model=None, trainer=None):
    """Evaluate a model on the validation split, or save its predictions.

    Args:
        args: parsed command-line namespace.
        model: model to evaluate. When None, it is loaded from
            ``args.eval_checkpoint``.
        trainer: Lightning trainer to reuse. When None, a fresh one is built.

    Returns:
        The result of ``trainer.test``. With ``--save_predictions PATH``, it
        is instead the list of ``trainer.predict`` outputs. In that case every
        tensor is also moved to CPU and written to PATH with ``torch.save``.

    Raises:
        ValueError: if ``model`` is None and ``args.eval_checkpoint`` is unset.
    """
    if model is None:
        if args.eval_checkpoint is None:
            raise ValueError('test() needs either a model or --eval_checkpoint')
        model = SetPredictionModel.load_from_checkpoint(checkpoint_path=args.eval_checkpoint, args=args)
    if trainer is None:
        trainer = pl.Trainer(gpus=args.n_gpus, num_nodes=1)
    # NOTE(review): trainer.test/predict appear to be governed by
    # limit_test_batches / limit_predict_batches, so undoing the training-time
    # validation limit here likely has no effect — confirm before relying on it.
    trainer.limit_val_batches = 1.0
    dataloader = model.val_dataloader()
    if not args.save_predictions:
        return trainer.test(model, dataloader)

    outputs = trainer.predict(model, dataloader)
    # Create the destination directory so torch.save does not fail on a new path.
    Path(args.save_predictions).parent.mkdir(parents=True, exist_ok=True)
    torch.save([[o.cpu().detach() for o in output] for output in outputs], args.save_predictions)
    return outputs


def main():
    """CLI entry point: seed everything, then train or evaluate.

    Evaluation runs when ``--eval_checkpoint`` is given; otherwise a new
    model is trained.
    """
    options = cmdline_args()
    pl.seed_everything(options.seed)

    run = train if options.eval_checkpoint is None else test
    run(options)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
